Example: Calibration
This example shows how to calibrate a classifier through atom.
The data used is a variation on the Australian weather dataset from Kaggle. You can download it from here. The goal of this dataset is to predict whether or not it will rain tomorrow by training a binary classifier on the target column RainTomorrow.
Load the data
In [1]:
# Import packages
import pandas as pd
from atom import ATOMClassifier
In [2]:
# Load the data
X = pd.read_csv("docs_source/examples/datasets/weatherAUS.csv")
# Let's have a look
X.head()
Out[2]:
| | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | MelbourneAirport | 18.0 | 26.9 | 21.4 | 7.0 | 8.9 | SSE | 41.0 | W | SSE | ... | 95.0 | 54.0 | 1019.5 | 1017.0 | 8.0 | 5.0 | 18.5 | 26.0 | Yes | 0 |
| 1 | Adelaide | 17.2 | 23.4 | 0.0 | NaN | NaN | S | 41.0 | S | WSW | ... | 59.0 | 36.0 | 1015.7 | 1015.7 | NaN | NaN | 17.7 | 21.9 | No | 0 |
| 2 | Cairns | 18.6 | 24.6 | 7.4 | 3.0 | 6.1 | SSE | 54.0 | SSE | SE | ... | 78.0 | 57.0 | 1018.7 | 1016.6 | 3.0 | 3.0 | 20.8 | 24.1 | Yes | 0 |
| 3 | Portland | 13.6 | 16.8 | 4.2 | 1.2 | 0.0 | ESE | 39.0 | ESE | ESE | ... | 76.0 | 74.0 | 1021.4 | 1020.5 | 7.0 | 8.0 | 15.6 | 16.0 | Yes | 1 |
| 4 | Walpole | 16.4 | 19.9 | 0.0 | NaN | NaN | SE | 44.0 | SE | SE | ... | 78.0 | 70.0 | 1019.4 | 1018.9 | NaN | NaN | 17.4 | 18.1 | No | 0 |
5 rows × 22 columns
Run the pipeline
In [3]:
atom = ATOMClassifier(X, "RainTomorrow", n_rows=1e4, verbose=1, warnings=False)
# Apply data cleaning steps
atom.clean()
atom.impute(strat_num="median", strat_cat="most_frequent")
atom.encode(strategy="target", max_onehot=5, infrequent_to_value=0.05)
# Train a Gaussian Naive Bayes model
atom.run("gnb")
<< ================== ATOM ================== >>
Algorithm task: binary classification.
Dataset stats ==================== >>
Shape: (10000, 22)
Train set size: 8000
Test set size: 2000
-------------------------------------
Memory: 4.34 MB
Scaled: False
Missing values: 22393 (10.2%)
Categorical features: 5 (23.8%)
Duplicate samples: 1 (0.0%)
Fitting Cleaner...
Cleaning the data...
Fitting Imputer...
Imputing missing values...
Fitting Encoder...
Encoding categorical columns...
Training ========================= >>
Models: GNB
Metric: f1
Results for GaussianNB:
Fit ---------------------------------------------
Train evaluation --> f1: 0.5628
Test evaluation --> f1: 0.5965
Time elapsed: 0.031s
-------------------------------------------------
Total time: 0.031s
Final results ==================== >>
Total time: 0.033s
-------------------------------------
GaussianNB --> f1: 0.5965
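The impute step in the pipeline above fills numerical columns with the median and categorical columns with the most frequent value. As a rough illustration of what those two strategies do, here is a minimal sketch using scikit-learn's SimpleImputer on a made-up two-column frame. The toy values and variable names are hypothetical, and the sketch assumes atom's strat_num and strat_cat options behave like the SimpleImputer strategies of the same name.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Hypothetical toy data with one missing value per column
toy = pd.DataFrame(
    {
        "MinTemp": [14.4, np.nan, 21.3, 6.8],
        "WindGustDir": pd.Series(["SE", "SW", np.nan, "SE"], dtype=object),
    }
)

# Numerical columns: replace missing values with the column median
num_imputer = SimpleImputer(strategy="median")
toy[["MinTemp"]] = num_imputer.fit_transform(toy[["MinTemp"]])

# Categorical columns: replace missing values with the most frequent category
cat_imputer = SimpleImputer(strategy="most_frequent")
toy[["WindGustDir"]] = cat_imputer.fit_transform(toy[["WindGustDir"]])

print(toy)
```

The missing MinTemp becomes the median of the observed values and the missing WindGustDir becomes the most common direction.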
Analyze the results
In [4]:
# Check the model's calibration
atom.plot_calibration()
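A calibration plot compares the predicted probability in each bin with the fraction of samples in that bin that are actually positive. For a well-calibrated model the two are close. The sketch below computes those numbers directly with scikit-learn's calibration_curve on synthetic data. It is not atom's implementation. The dataset, the variable names and the choice of 10 bins are assumptions made for illustration.

```python
from sklearn.calibration import calibration_curve
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Synthetic stand-in for the weather data (hypothetical, for illustration only)
X_demo, y_demo = make_classification(
    n_samples=2000, n_features=10, n_informative=5, n_redundant=3, random_state=1
)
X_train, X_test, y_train, y_test = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=1
)

model = GaussianNB().fit(X_train, y_train)
proba = model.predict_proba(X_test)[:, 1]  # predicted probability of the positive class

# Fraction of positives and mean predicted probability per (non-empty) bin
frac_pos, mean_pred = calibration_curve(y_test, proba, n_bins=10)

for p, f in zip(mean_pred, frac_pos):
    print(f"mean predicted: {p:.2f} -> observed fraction of positives: {f:.2f}")
```

Points far from the diagonal (predicted equal to observed) indicate over- or under-confident probabilities.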
In [5]:
# Let's try to improve it using the calibrate method
atom.winner.calibrate(method="isotonic", cv=5)
Results for GaussianNB:
Fit ---------------------------------------------
Train evaluation --> f1: 0.4869
Test evaluation --> f1: 0.5254
Time elapsed: 0.136s
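Note that the f1 score on the test set went from 0.5965 to 0.5254 after calibration. Calibration aims to make the predicted probabilities more reliable, not to improve a threshold-based metric such as f1, so a change in either direction is possible. The sketch below shows the same idea outside atom, assuming the calibrate method behaves like scikit-learn's CalibratedClassifierCV with the same method and cv arguments. The synthetic dataset and the variable names are hypothetical, and the Brier score is added here only as one common way to quantify calibration.

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.metrics import brier_score_loss, f1_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Synthetic, imbalanced stand-in for the weather data (hypothetical)
X_demo, y_demo = make_classification(
    n_samples=2000,
    n_features=10,
    n_informative=5,
    n_redundant=3,
    weights=[0.8, 0.2],
    random_state=1,
)
X_train, X_test, y_train, y_test = train_test_split(
    X_demo, y_demo, test_size=0.2, stratify=y_demo, random_state=1
)

# Uncalibrated model versus the same model wrapped in an isotonic calibrator
raw = GaussianNB().fit(X_train, y_train)
calibrated = CalibratedClassifierCV(GaussianNB(), method="isotonic", cv=5)
calibrated.fit(X_train, y_train)

scores = {}
for name, est in [("raw", raw), ("calibrated", calibrated)]:
    proba = est.predict_proba(X_test)[:, 1]
    brier = brier_score_loss(y_test, proba)  # lower means better calibrated probabilities
    f1 = f1_score(y_test, est.predict(X_test))  # depends on the default 0.5 threshold
    scores[name] = (brier, f1)
    print(f"{name}: brier={brier:.4f}, f1={f1:.4f}")
```

Comparing the two rows shows how the probability quality and the f1 score can move independently of each other.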
In [6]:
# And check again...
atom.plot_calibration()